# -*- coding: utf-8 -*-
"""
Created on Wed Mar  1 10:40:21 2024
@author: xinyuan.wei
Updated 06/27/2024
Summarize and extract FIA (Forest Inventory and Analysis) plot composition data
"""
import os
import pandas as pd
import numpy as np

#########################################################################
### Input information and load data ###
#########################################################################
# Project layout: the FIA_Data and FIA_Plot_Composition folders live in the
# parent of the directory this script is run from.
current_directory = os.getcwd()
parent_directory = os.path.dirname(current_directory)

# State whose FIA export (the <state>_CSV folder) is processed
state = 'VT'
filename = state + '_CSV'
file_path = os.path.join(parent_directory, 'FIA_Data', filename)

# Output locations: the per-state composition table, and the file name used
# for the multi-state merge produced by merge_csv_files
savename = state + '_Plot_Composition.csv'
save_path = os.path.join(parent_directory, 'FIA_Plot_Composition')
savedata = os.path.join(save_path, savename)
mergedfile = 'NE_Plot_Composition.csv'

# Load the Tree, Plot and Condition tables. os.path.join (rather than a
# hard-coded '/') keeps separators consistent with file_path on every OS.
tree_data = pd.read_csv(os.path.join(file_path, state + '_TREE.csv'),
                        low_memory=False)
plot_data = pd.read_csv(os.path.join(file_path, state + '_PLOT.csv'))
cond_data = pd.read_csv(os.path.join(file_path, state + '_COND.csv'))

# Inventory-year window (INVYR), inclusive on both ends, used by plot_species
ymin = 2002
ymax = 2021

#########################################################################
### Function to extract the state plot forest composition data ###
#########################################################################
# Columns identifying one plot visit. In FIADB, PLOT is only unique together
# with the state/unit/county codes, so COUNTYCD must be part of the key:
# grouping on (INVYR, PLOT) alone pools trees from different plots that
# happen to share a plot number.
_PLOT_KEYS = ['INVYR', 'COUNTYCD', 'PLOT']

# Column order of the table written by plot_species
_OUTPUT_COLUMNS = ['INVYR', 'STDAGE', 'COUNTYCD', 'PLOT', 'LAT', 'LON', 'ELEV',
                   'CONDID', 'CONDPROP_UNADJ', 'DSTRBCD1', 'TRTCD1', 'TRTYR1',
                   'sindex', 'evenness', 'eAG', 'eBG', 'dominant_count_species',
                   'dominant_area_species']


def _shannon_index(species):
    """Return the Shannon-Wiener index H' = -sum(p_i * ln p_i) of species codes."""
    proportions = species.value_counts(normalize=True)
    return -np.sum(proportions * np.log(proportions))


def _pielou_evenness(species):
    """Return Pielou's evenness J = H' / ln(S); NaN when fewer than 2 species."""
    counts = species.value_counts()
    num_species = len(counts)
    if num_species < 2:
        # Undefined for a single species (ln(1) == 0)
        return np.nan
    proportions = counts / counts.sum()
    return -np.sum(proportions * np.log(proportions)) / np.log(num_species)


def _dominant_by_count(species):
    """Return the most frequent species code."""
    return species.value_counts().idxmax()


def _dominant_by_area(trees):
    """Return the species of the largest-stem tree, or NaN if no stem area."""
    if trees['STEM_AREA'].max() > 0:
        return trees.loc[trees['STEM_AREA'].idxmax(), 'SPCD']
    return np.nan


def plot_species(tree_data, plot_data, cond_data, savefile,
                 year_min=None, year_max=None):
    """Summarize forest composition per plot visit and save it as CSV.

    For every (INVYR, COUNTYCD, PLOT) with trees inside the year window this
    computes the Shannon-Wiener index, Pielou's evenness, summed above/below
    ground carbon, and the dominant species by tree count and by stem area.
    It then attaches LAT/LON/ELEV from the plot table and condition attributes
    from the condition table.

    Parameters
    ----------
    tree_data, plot_data, cond_data : pandas.DataFrame
        FIA TREE, PLOT and COND tables. Plot and condition data are joined on
        INVYR, COUNTYCD and PLOT.
    savefile : str
        Destination CSV path; written without the index.
    year_min, year_max : int, optional
        Inclusive INVYR window. These default to the module-level ``ymin`` /
        ``ymax``.

    Returns
    -------
    pandas.DataFrame
        The saved table (columns as in ``_OUTPUT_COLUMNS``). Each plot visit
        gets one row per matching condition record; the plot-level metrics
        are repeated on each of those rows.
    """
    if year_min is None:
        year_min = ymin
    if year_max is None:
        year_max = ymax

    # Keep only inventories inside the year window, and only needed columns.
    # NOTE(review): STATUSCD is retained but not used to filter, so every tree
    # record counts towards diversity/dominance. Confirm whether only live
    # trees should be included.
    columns_keep = ["INVYR", "STATECD", "COUNTYCD", "PLOT", "SUBP", "TREE",
                    "CONDID", "STATUSCD", "SPCD", "DIA", "HT", "TPA_UNADJ",
                    "CARBON_AG", "CARBON_BG"]
    in_window = tree_data['INVYR'].between(year_min, year_max)
    # Explicit copy: new columns are added below, and the source frame must
    # not be touched (avoids chained-assignment / SettingWithCopy issues).
    trees = tree_data.loc[in_window, columns_keep].copy()

    # Per-tree carbon expanded to an area basis. CARBON_* * TPA_UNADJ is per
    # acre; 0.000112085 = 0.45359237 kg/lb / 4046.856 m2/acre converts
    # lb/acre to kg/m2.
    unitsf = 0.000112085
    trees['STEM_AREA'] = ((trees['DIA'] / 2) ** 2 * np.pi).fillna(0)
    trees['eAG'] = trees['CARBON_AG'] * trees['TPA_UNADJ'] * unitsf
    trees['eBG'] = trees['CARBON_BG'] * trees['TPA_UNADJ'] * unitsf

    if trees.empty:
        # Nothing in the window: still write a well-formed (header-only) file
        results = pd.DataFrame(columns=_OUTPUT_COLUMNS)
        results.to_csv(savefile, index=False)
        return results

    # Plot-level metrics, all indexed by the same group keys so they align
    grouped = trees.groupby(_PLOT_KEYS)
    metrics = grouped[['eAG', 'eBG']].sum()
    metrics['sindex'] = grouped['SPCD'].agg(_shannon_index)
    metrics['evenness'] = grouped['SPCD'].agg(_pielou_evenness)
    metrics['dominant_count_species'] = grouped['SPCD'].agg(_dominant_by_count)
    metrics['dominant_area_species'] = (
        grouped[['SPCD', 'STEM_AREA']].apply(_dominant_by_area))
    metrics = metrics.reset_index()

    # Attach location and condition attributes (left joins keep every plot)
    site_cols = _PLOT_KEYS + ['LAT', 'LON', 'ELEV']
    cond_cols = _PLOT_KEYS + ['STDAGE', 'CONDID', 'CONDPROP_UNADJ',
                              'DSTRBCD1', 'TRTCD1', 'TRTYR1']
    results = metrics.merge(plot_data[site_cols], on=_PLOT_KEYS, how='left')
    results = results.merge(cond_data[cond_cols], on=_PLOT_KEYS, how='left')

    results = results.sort_values(by=_PLOT_KEYS)[_OUTPUT_COLUMNS]
    results.to_csv(savefile, index=False)
    return results

# Call the function
#plot_species(tree_data, plot_data, cond_data, savedata)

#########################################################################
### Function to merge plot forest composition data ###
#########################################################################
def merge_csv_files(directory):
    """Concatenate every per-state ``*.csv`` table found in *directory*.

    Each file name is expected to start with a state prefix followed by an
    underscore (e.g. ``VT_Plot_Composition.csv``). That prefix is prepended to
    the file's PLOT values (``'VT' + '10' -> 'VT10'``) so plot identifiers from
    different states remain distinguishable after merging.

    Parameters
    ----------
    directory : str
        Folder holding the per-state CSV files; non-CSV files are ignored.

    Returns
    -------
    pandas.DataFrame
        All rows of all files with a fresh RangeIndex. Files are processed in
        sorted file-name order.

    Raises
    ------
    ValueError
        If *directory* contains no ``.csv`` files.
    """
    # os.listdir returns entries in arbitrary order; sort so the merged row
    # order is reproducible between runs and machines.
    csv_files = sorted(f for f in os.listdir(directory) if f.endswith('.csv'))
    if not csv_files:
        raise ValueError(f"No CSV files found in {directory!r}")

    data_frames = []
    for csv_name in csv_files:
        state_prefix = csv_name.split('_')[0]
        df = pd.read_csv(os.path.join(directory, csv_name))
        df['PLOT'] = state_prefix + df['PLOT'].astype(str)
        data_frames.append(df)

    return pd.concat(data_frames, ignore_index=True)

# Call the function
#merged_data = merge_csv_files(save_path)
#merged_data.to_csv(os.path.join(parent_directory, mergedfile), index=False)

